//
// Created by DefTruth on 2021/11/21.
//

#include "tnn_face_landmarks_1000.h"

using tnncv::TNNFaceLandmark1000;

TNNFaceLandmark1000::TNNFaceLandmark1000(const std::string &_proto_path,
                                         const std::string &_model_path,
                                         unsigned int _num_threads) :
    BasicTNNHandler(_proto_path, _model_path, _num_threads)
{
}

void TNNFaceLandmark1000::transform(const cv::Mat &mat)
{
  cv::Mat mat_rs;
  cv::resize(mat, mat_rs, cv::Size(input_width, input_height));
  cv::cvtColor(mat_rs, mat_rs, cv::COLOR_BGR2GRAY);
  // push into input_mat
  input_mat = std::make_shared<tnn::Mat>(input_device_type, tnn::NGRAY,
                                         input_shape, (void *) mat_rs.data);
  if (!input_mat->GetData())
  {
#ifdef LITETNN_DEBUG
    std::cout << "input_mat == nullptr! transform failed\n";
#endif
  }
}

void TNNFaceLandmark1000::detect(const cv::Mat &mat, types::Landmarks &landmarks)
{
  if (mat.empty()) return;
  float img_height = static_cast<float>(mat.rows);
  float img_width = static_cast<float>(mat.cols);

  // 1. make input mat
  this->transform(mat);
  // 2. set input_mat
  tnn::MatConvertParam input_cvt_param;
  input_cvt_param.scale = scale_vals;
  input_cvt_param.bias = bias_vals;

  tnn::Status status;
  status = instance->SetInputMat(input_mat, input_cvt_param);
  if (status != tnn::TNN_OK)
  {
#ifdef LITETNN_DEBUG
    std::cout << status.description().c_str() << "\n";
#endif
    return;
  }
  // 3. forward
  status = instance->Forward();
  if (status != tnn::TNN_OK)
  {
#ifdef LITETNN_DEBUG
    std::cout << status.description().c_str() << "\n";
#endif
    return;
  }
  // 4. fetch landmarks.
  tnn::MatConvertParam cvt_param;
  std::shared_ptr<tnn::Mat> landmarks_norm; // (1,1953)
  status = instance->GetOutputMat(landmarks_norm, cvt_param, "output0", output_device_type);

  if (status != tnn::TNN_OK)
  {
#ifdef LITETNN_DEBUG
    std::cout << status.description().c_str() << "\n";
#endif
    return;
  }
  auto landmark_dims = landmarks_norm->GetDims();

  unsigned int num_landmarks = landmark_dims.at(1);
  if (num_landmarks > 1946) num_landmarks = 1946;

  const float *landmarks_ptr = (float *) landmarks_norm->GetData();

  for (unsigned int i = 0; i < num_landmarks; i += 2)
  {
    float x = landmarks_ptr[i];
    float y = landmarks_ptr[i + 1];

    x = std::min(std::max(0.f, x), 1.0f);
    y = std::min(std::max(0.f, y), 1.0f);

    landmarks.points.push_back(cv::Point2f(x * img_width, y * img_height));
  }
  landmarks.flag = true;
}
